#!/usr/bin/env python3
"""
Violence Perception Index (VPI) - Fine-tuning Spanish Transformer Models
=========================================================================

This script fine-tunes a pretrained Spanish transformer (by default BETO,
`dccuchile/bert-base-spanish-wwm-cased`; see MODEL_NAME) on LLM-labeled
YouTube comments to classify violence-related discourse.

Two models are trained:
1. Binary classifier: Does the comment discuss violence? (yes/no)
2. Regression model: Violence intensity score (0-10)

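Ground truth labels are created across 4 LLM annotators by vote thresholding
(binary; by default a comment is positive when at least 2 of the 4 annotators
flag it) and by averaging the annotators' scores (regression).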

Usage:
    python train_violence_classifier.py --data_dir ./multi_model_results --output_dir ./trained_models

Requirements:
    pip install transformers datasets scikit-learn pandas torch accelerate

Author: Francesco Bailo
Date: January 2026
"""

import os
import glob
import argparse
import pandas as pd
import numpy as np
from pathlib import Path
from datetime import datetime

import torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    cohen_kappa_score, mean_squared_error, mean_absolute_error
)

from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    EarlyStoppingCallback
)
from datasets import Dataset

# =============================================================================
# CONFIGURATION
# =============================================================================

# Pretrained checkpoint that gets fine-tuned. The current default is BETO
# (a Spanish BERT, not a RoBERTa model) -- the same id listed as "BETO" below.
MODEL_NAME = "dccuchile/bert-base-spanish-wwm-cased"

# Alternative Spanish checkpoints (uncomment one to use instead):
# MODEL_NAME = "PlanTL-GOB-ES/roberta-base-bne"  # RoBERTa-BNE (MarIA, BSC / Plan TL)
# NOTE(review): the id below was previously labelled "MarIA", but its "-ca"
# suffix suggests a Catalan RoBERTa -- confirm it suits Spanish text first.
# MODEL_NAME = "BSC-LT/roberta-base-ca-v2"

# Per-annotator columns in the input CSVs, one per LLM (4 annotators).
# Binary "does this comment discuss violence?" votes; coerced to bool and
# counted by create_consensus_labels().
LLM_BINARY_COLS = [
    'discusses_violence_qwen_qwen3_vl_4b',
    'discusses_violence_ibm_granite_3.2_8b',
    'discusses_violence_google_gemma_3n_e4b',
    'discusses_violence_meta_llama_3_8b_instruct'
]

# Matching violence-intensity scores from the same annotators (0-10 scale;
# averaged and clipped to [0, 10] by create_consensus_labels()).
LLM_SCORE_COLS = [
    'violence_score_qwen_qwen3_vl_4b',
    'violence_score_ibm_granite_3.2_8b',
    'violence_score_google_gemma_3n_e4b',
    'violence_score_meta_llama_3_8b_instruct'
]

# Column holding the comment text fed to the tokenizer.
TEXT_COL = 'text_original'
# Comment identifier carried into the HuggingFace datasets when present.
ID_COL = 'comment_id'

# =============================================================================
# DATA LOADING AND PREPROCESSING
# =============================================================================

def load_all_csvs(data_dir: str) -> pd.DataFrame:
    """
    Read and concatenate the annotated CSV files found in ``data_dir``.

    Files whose name contains ``COMPARISON`` are preferred. Only when no such
    file exists does the loader fall back to every ``*.csv`` in the directory.
    Files are read in sorted filename order and stacked with a fresh index.

    Args:
        data_dir: Directory to search (non-recursively) for CSV files.

    Returns:
        One DataFrame holding the rows of all matched files.

    Raises:
        FileNotFoundError: if neither pattern matches any file.
    """
    for file_pattern in ("*COMPARISON*.csv", "*.csv"):
        matches = glob.glob(os.path.join(data_dir, file_pattern))
        if matches:
            break
    else:
        raise FileNotFoundError(f"No CSV files found in {data_dir}")

    print(f"Found {len(matches)} CSV files")

    frames = []
    for csv_path in sorted(matches):
        frame = pd.read_csv(csv_path)
        frames.append(frame)
        print(f"  Loaded {csv_path}: {len(frame)} rows")

    stacked = pd.concat(frames, ignore_index=True)
    print(f"Total comments loaded: {len(stacked)}")
    return stacked


# Lower-cased text values that count as a positive ("discusses violence") vote.
_TRUE_TOKENS = frozenset({'true', '1', 'yes', 't'})


def _coerce_votes(values: pd.Series) -> pd.Series:
    """Map one annotator's raw binary column to a boolean Series.

    A cell is a positive vote when its stripped, lower-cased text is one of
    ``_TRUE_TOKENS`` or when it parses as the number 1. The numeric check
    matters because pandas reads a 0/1 column containing any blanks as
    float64, so a positive vote arrives as ``1.0`` (text ``'1.0'``), which the
    token list alone would miss. Everything else, including missing values,
    counts as "no".
    """
    as_text = values.astype(str).str.strip().str.lower()
    as_number = pd.to_numeric(values, errors='coerce')
    return as_text.isin(_TRUE_TOKENS) | (as_number == 1)


def create_consensus_labels(
    df: pd.DataFrame,
    majority_threshold: int = 2,
    binary_cols=None,
    score_cols=None,
) -> pd.DataFrame:
    """
    Create ground truth labels from LLM consensus.

    Binary: a comment is positive when at least ``majority_threshold``
    annotators flag it. The default of 2 means >= 2 of the 4 configured LLMs;
    2/4 is a tie, not a strict majority. Missing or unrecognised binary values
    count as "no".
    Regression: mean of the annotators' numeric scores, ignoring missing or
    non-numeric values, clipped to [0, 10]. Rows with no usable score get NaN.

    Args:
        df: DataFrame with LLM annotations. Not modified; a copy is returned.
        majority_threshold: Minimum number of positive votes for label 1.
        binary_cols: Binary annotator columns to use (defaults to
            ``LLM_BINARY_COLS``). Columns absent from ``df`` are skipped.
        score_cols: Score annotator columns to use (defaults to
            ``LLM_SCORE_COLS``). Columns absent from ``df`` are skipped.

    Returns:
        Copy of ``df`` with the binary columns coerced to bool, the score
        columns coerced to numeric, and added ``llm_votes``,
        ``label_binary`` (0/1) and ``label_regression`` columns.

    Raises:
        ValueError: if none of the binary columns or none of the score
            columns are present, or if ``majority_threshold`` is not between
            1 and the number of binary columns present.
    """
    if binary_cols is None:
        binary_cols = LLM_BINARY_COLS
    if score_cols is None:
        score_cols = LLM_SCORE_COLS

    df = df.copy()

    binary_cols_present = [c for c in binary_cols if c in df.columns]
    score_cols_present = [c for c in score_cols if c in df.columns]

    # Without these columns the labels would silently be all-negative
    # (binary) or all-NaN (regression); fail loudly instead.
    if not binary_cols_present:
        raise ValueError(
            f"None of the binary annotation columns {list(binary_cols)} were found"
        )
    if not score_cols_present:
        raise ValueError(
            f"None of the score annotation columns {list(score_cols)} were found"
        )
    if not 1 <= majority_threshold <= len(binary_cols_present):
        raise ValueError(
            f"majority_threshold must be between 1 and {len(binary_cols_present)} "
            f"(got {majority_threshold})"
        )

    # --- Binary consensus via vote threshold ---
    for col in binary_cols_present:
        df[col] = _coerce_votes(df[col])

    df['llm_votes'] = df[binary_cols_present].sum(axis=1)
    df['label_binary'] = (df['llm_votes'] >= majority_threshold).astype(int)

    # --- Regression consensus via averaging ---
    for col in score_cols_present:
        df[col] = pd.to_numeric(df[col], errors='coerce')

    # mean(axis=1) skips NaN; clip after averaging to keep the 0-10 scale.
    df['label_regression'] = df[score_cols_present].mean(axis=1).clip(0, 10)

    # --- Summary statistics ---
    print("\n=== Label Distribution ===")
    print(f"Binary - Positive (violence): {df['label_binary'].sum()} ({df['label_binary'].mean()*100:.1f}%)")
    print(f"Binary - Negative (no violence): {(1-df['label_binary']).sum()} ({(1-df['label_binary'].mean())*100:.1f}%)")
    print(f"\nRegression - Mean score: {df['label_regression'].mean():.2f}")
    print(f"Regression - Std: {df['label_regression'].std():.2f}")
    print(f"Regression - Min/Max: {df['label_regression'].min():.2f} / {df['label_regression'].max():.2f}")

    # LLM agreement statistics
    print("\nLLM Agreement Distribution:")
    print(df['llm_votes'].value_counts().sort_index())

    return df


def prepare_datasets(
    df: pd.DataFrame,
    tokenizer,
    test_size: float = 0.15,
    val_size: float = 0.15,
    max_length: int = 128,
):
    """
    Split labelled comments into train/validation/test sets.

    Rows missing the text or either label are dropped first. Both splits are
    stratified on ``label_binary`` with ``random_state=42``, so the positive
    rate is similar across the three sets and the split is reproducible.

    Tokenization is deferred. The returned ``create_hf_dataset`` factory lets
    the caller build one HuggingFace ``Dataset`` per split and per target
    (binary or regression) from the same pandas splits.

    Args:
        df: DataFrame with ``TEXT_COL``, ``label_binary`` and
            ``label_regression`` columns.
        tokenizer: HuggingFace tokenizer used by ``create_hf_dataset``.
        test_size: Fraction of the cleaned data held out for the test set.
        val_size: Fraction of the cleaned data (not of the remainder after
            the test split) used for validation. It is rescaled internally
            for the second split.
        max_length: Token length every example is padded/truncated to
            (default 128; most YouTube comments are short).

    Returns:
        Tuple ``(train_df, val_df, test_df, create_hf_dataset)``, where
        ``create_hf_dataset(subset_df, label_col)`` returns a tokenized
        ``Dataset`` with ``text``, ``labels`` (copied from ``label_col``) and
        ``comment_id`` columns plus the tokenizer outputs.

    Raises:
        ValueError: if a split fraction is outside (0, 1) or
            ``test_size + val_size >= 1``.
    """
    if not (0 < test_size < 1 and 0 < val_size < 1):
        raise ValueError(
            f"test_size and val_size must be in (0, 1); got {test_size}, {val_size}"
        )
    if test_size + val_size >= 1:
        raise ValueError(
            f"test_size + val_size must be < 1 to leave training data; "
            f"got {test_size} + {val_size}"
        )

    # Remove rows with missing text or labels
    df_clean = df.dropna(subset=[TEXT_COL, 'label_binary', 'label_regression'])
    print(f"\nAfter removing missing values: {len(df_clean)} comments")

    # Stratified split to maintain class balance
    train_val_df, test_df = train_test_split(
        df_clean,
        test_size=test_size,
        stratify=df_clean['label_binary'],
        random_state=42
    )

    train_df, val_df = train_test_split(
        train_val_df,
        test_size=val_size / (1 - test_size),  # rescale: val_size is a fraction of the total
        stratify=train_val_df['label_binary'],
        random_state=42
    )

    print(f"Train: {len(train_df)}, Validation: {len(val_df)}, Test: {len(test_df)}")

    def tokenize_data(examples):
        # Fixed-length padding: the Trainer's default collator needs
        # equal-length inputs because no tokenizer is passed to it.
        return tokenizer(
            examples['text'],
            padding='max_length',
            truncation=True,
            max_length=max_length
        )

    def create_hf_dataset(subset_df, label_col):
        """Create a tokenized HuggingFace Dataset from one pandas split."""
        if ID_COL in subset_df.columns:
            comment_ids = subset_df[ID_COL].tolist()
        else:
            comment_ids = list(range(len(subset_df)))
        data = {
            # astype(str): a cell pandas parsed as a number (e.g. a purely
            # numeric comment) would otherwise be rejected by the tokenizer.
            'text': subset_df[TEXT_COL].astype(str).tolist(),
            'labels': subset_df[label_col].tolist(),
            'comment_id': comment_ids,
        }
        dataset = Dataset.from_dict(data)
        return dataset.map(tokenize_data, batched=True)

    return train_df, val_df, test_df, create_hf_dataset


# =============================================================================
# MODEL TRAINING
# =============================================================================

def compute_metrics_binary(eval_pred):
    """Compute evaluation metrics for the binary classifier.

    Args:
        eval_pred: ``(predictions, labels)`` pair supplied by the Trainer.
            Predictions are per-class logits from the 2-label head (argmax
            over axis 1 gives the class); labels are the 0/1 targets. If the
            model returns a tuple of arrays, the first one is used as logits.

    Returns:
        Dict with accuracy, precision, recall and F1 for the positive class
        (label 1), plus Cohen's kappa. Kappa can be NaN when labels and
        predictions are all the same single class.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs; the logits come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predicted_classes = np.argmax(predictions, axis=1)

    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, predicted_classes, average='binary', zero_division=0
    )
    accuracy = accuracy_score(labels, predicted_classes)
    kappa = cohen_kappa_score(labels, predicted_classes)

    return {
        'accuracy': float(accuracy),
        'precision': float(precision),
        'recall': float(recall),
        'f1': float(f1),
        'cohen_kappa': float(kappa)
    }


def compute_metrics_regression(eval_pred):
    """Compute evaluation metrics for the regression model.

    Args:
        eval_pred: ``(predictions, labels)`` pair supplied by the Trainer.
            Predictions come from the single-output head and are flattened
            to one value per example. If the model returns a tuple of arrays,
            the first one is used.

    Returns:
        Dict with MSE, RMSE, MAE and the Pearson correlation between
        predictions and labels. The correlation is NaN when either side is
        constant (it is undefined there).
    """
    predictions, labels = eval_pred
    # Some models return extra outputs; the predictions come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions, dtype=np.float64).reshape(-1)
    labels = np.asarray(labels, dtype=np.float64).reshape(-1)

    mse = mean_squared_error(labels, predictions)
    mae = mean_absolute_error(labels, predictions)
    rmse = np.sqrt(mse)

    # np.corrcoef divides by the standard deviations and would emit a
    # RuntimeWarning + NaN on constant input; short-circuit that case.
    if predictions.size > 1 and np.std(predictions) > 0 and np.std(labels) > 0:
        correlation = float(np.corrcoef(predictions, labels)[0, 1])
    else:
        correlation = float('nan')

    return {
        'mse': float(mse),
        'rmse': float(rmse),
        'mae': float(mae),
        'correlation': correlation
    }


def train_binary_classifier(
    train_dataset,
    val_dataset,
    output_dir: str,
    num_epochs: int = 10,
    batch_size: int = 16,
    learning_rate: float = 2e-5,
    early_stopping_patience: int = 3,
):
    """
    Fine-tune ``MODEL_NAME`` as a two-class (violence / no violence) classifier.

    The model is evaluated and checkpointed every epoch under
    ``<output_dir>/binary_classifier``. The checkpoint with the best
    validation F1 is reloaded at the end and saved, together with the
    tokenizer, to ``<output_dir>/binary_classifier_final``. Mixed precision
    (fp16) is enabled when CUDA is available.

    Args:
        train_dataset: Tokenized training dataset with integer ``labels``.
        val_dataset: Tokenized validation dataset with integer ``labels``.
        output_dir: Root directory for checkpoints, logs and the final model.
        num_epochs: Maximum number of training epochs.
        batch_size: Per-device training batch size (evaluation uses twice this).
        learning_rate: Peak learning rate.
        early_stopping_patience: Number of consecutive epoch evaluations
            without F1 improvement before training stops.

    Returns:
        Tuple ``(trainer, tokenizer)``: the ``Trainer`` holding the best
        model, and the tokenizer loaded for ``MODEL_NAME``.
    """
    print("\n" + "="*60)
    print("TRAINING BINARY CLASSIFIER")
    print("="*60)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=2,
        problem_type="single_label_classification"
    )

    training_args = TrainingArguments(
        output_dir=os.path.join(output_dir, "binary_classifier"),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        warmup_ratio=0.1,
        weight_decay=0.01,
        learning_rate=learning_rate,
        logging_dir=os.path.join(output_dir, "logs_binary"),
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        save_total_limit=2,
        report_to="none",  # Disable wandb/tensorboard
        fp16=torch.cuda.is_available(),  # Use mixed precision if GPU available
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_binary,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)]
    )

    print(f"\nDevice: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    trainer.train()

    # Save final model (the best checkpoint, since load_best_model_at_end=True)
    final_path = os.path.join(output_dir, "binary_classifier_final")
    trainer.save_model(final_path)
    tokenizer.save_pretrained(final_path)

    print(f"\nBinary classifier saved to: {final_path}")

    return trainer, tokenizer


def train_regression_model(
    train_dataset,
    val_dataset,
    output_dir: str,
    num_epochs: int = 10,
    batch_size: int = 16,
    learning_rate: float = 2e-5,
    early_stopping_patience: int = 3,
):
    """
    Fine-tune ``MODEL_NAME`` with a single-output head to predict violence scores.

    The model is evaluated and checkpointed every epoch under
    ``<output_dir>/regression_model``. The checkpoint with the lowest
    validation RMSE is reloaded at the end and saved, together with the
    tokenizer, to ``<output_dir>/regression_model_final``. Mixed precision
    (fp16) is enabled when CUDA is available.

    Args:
        train_dataset: Tokenized training dataset with float ``labels``.
        val_dataset: Tokenized validation dataset with float ``labels``.
        output_dir: Root directory for checkpoints, logs and the final model.
        num_epochs: Maximum number of training epochs.
        batch_size: Per-device training batch size (evaluation uses twice this).
        learning_rate: Peak learning rate.
        early_stopping_patience: Number of consecutive epoch evaluations
            without RMSE improvement before training stops.

    Returns:
        Tuple ``(trainer, tokenizer)``: the ``Trainer`` holding the best
        model, and the tokenizer loaded for ``MODEL_NAME``.
    """
    print("\n" + "="*60)
    print("TRAINING REGRESSION MODEL")
    print("="*60)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSequenceClassification.from_pretrained(
        MODEL_NAME,
        num_labels=1,
        problem_type="regression"
    )

    training_args = TrainingArguments(
        output_dir=os.path.join(output_dir, "regression_model"),
        num_train_epochs=num_epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size * 2,
        warmup_ratio=0.1,
        weight_decay=0.01,
        learning_rate=learning_rate,
        logging_dir=os.path.join(output_dir, "logs_regression"),
        logging_steps=10,
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="rmse",
        greater_is_better=False,  # lower RMSE is better
        save_total_limit=2,
        report_to="none",
        fp16=torch.cuda.is_available(),
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_regression,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)]
    )

    print(f"\nDevice: {'CUDA' if torch.cuda.is_available() else 'CPU'}")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    trainer.train()

    # Save final model (the best checkpoint, since load_best_model_at_end=True)
    final_path = os.path.join(output_dir, "regression_model_final")
    trainer.save_model(final_path)
    tokenizer.save_pretrained(final_path)

    print(f"\nRegression model saved to: {final_path}")

    return trainer, tokenizer


# =============================================================================
# EVALUATION
# =============================================================================

def evaluate_on_test_set(trainer, test_dataset, model_type: str):
    """
    Run ``trainer.evaluate`` on the held-out split and print every metric.

    Float metrics are printed with four decimals; any other value is
    printed as-is.

    Args:
        trainer: Trained Trainer object (anything exposing ``evaluate``).
        test_dataset: Dataset passed straight to ``trainer.evaluate``.
        model_type: Label shown in the header, e.g. 'binary' or 'regression'.

    Returns:
        The metrics dictionary returned by ``trainer.evaluate``, unchanged.
    """
    print(f"\n=== Test Set Evaluation ({model_type}) ===")

    metrics = trainer.evaluate(test_dataset)

    for name, metric in metrics.items():
        shown = f"{metric:.4f}" if isinstance(metric, float) else f"{metric}"
        print(f"  {name}: {shown}")

    return metrics


def compare_with_dictionary_baseline(test_df, predictions_binary, predictions_regression):
    """
    Placeholder for comparing the fine-tuned models with the original
    dictionary-based approach.

    Not implemented yet: all arguments are ignored and ``None`` is returned.
    Once the dictionary baseline scores (e.g. a ``scalar_score`` column from
    the original data) are available, the comparison belongs here.
    """
    return None


# =============================================================================
# MAIN EXECUTION
# =============================================================================

def main(args):
    """Run the end-to-end training pipeline.

    Steps:
    1. Load the LLM-annotated CSVs and derive consensus labels.
    2. Write the processed data to ``processed_training_data.csv``.
    3. Split into train/validation/test.
    4. Fine-tune the binary classifier and/or the regression model, unless
       they are skipped via ``args.skip_binary`` / ``args.skip_regression``.
    5. Evaluate each trained model on the test split.
    6. Write the test split to ``test_set.csv``. It includes the models'
       per-comment predictions: ``pred_binary`` and ``pred_binary_prob``
       (softmax probability of label 1), and ``pred_regression``.

    Args:
        args: Parsed command-line arguments (see the ``argparse`` setup at
            the bottom of this file).
    """
    print("\n" + "="*60)
    print("VPI FINE-TUNING PIPELINE")
    print(f"Model: {MODEL_NAME}")
    print(f"Started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print("="*60)

    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Load data
    print("\n--- Loading Data ---")
    df = load_all_csvs(args.data_dir)

    # Create consensus labels
    print("\n--- Creating Consensus Labels ---")
    df = create_consensus_labels(df, majority_threshold=args.majority_threshold)

    # Save processed data
    processed_path = os.path.join(args.output_dir, "processed_training_data.csv")
    df.to_csv(processed_path, index=False)
    print(f"\nProcessed data saved to: {processed_path}")

    # Initialize tokenizer
    print("\n--- Initializing Tokenizer ---")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Prepare datasets
    print("\n--- Preparing Datasets ---")
    train_df, val_df, test_df, create_hf_dataset = prepare_datasets(
        df, tokenizer,
        test_size=args.test_size,
        val_size=args.val_size
    )

    # Test split plus per-comment predictions, filled in below. The Trainer
    # predicts in dataset order, and the datasets are built from test_df in
    # row order, so prediction arrays can be assigned positionally.
    test_output = test_df.copy()

    # Train binary classifier (its datasets are tokenized only when needed)
    if not args.skip_binary:
        train_dataset_binary = create_hf_dataset(train_df, 'label_binary')
        val_dataset_binary = create_hf_dataset(val_df, 'label_binary')
        test_dataset_binary = create_hf_dataset(test_df, 'label_binary')

        binary_trainer, _ = train_binary_classifier(
            train_dataset_binary,
            val_dataset_binary,
            args.output_dir,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate
        )

        # Evaluate on test set
        evaluate_on_test_set(binary_trainer, test_dataset_binary, 'binary')

        logits = binary_trainer.predict(test_dataset_binary).predictions
        if isinstance(logits, tuple):
            logits = logits[0]
        logits = np.asarray(logits, dtype=np.float64)
        # Numerically stable softmax; column 1 corresponds to label 1 (violence).
        exp_logits = np.exp(logits - logits.max(axis=1, keepdims=True))
        test_output['pred_binary'] = logits.argmax(axis=1)
        test_output['pred_binary_prob'] = exp_logits[:, 1] / exp_logits.sum(axis=1)

    # Train regression model (its datasets are tokenized only when needed)
    if not args.skip_regression:
        train_dataset_reg = create_hf_dataset(train_df, 'label_regression')
        val_dataset_reg = create_hf_dataset(val_df, 'label_regression')
        test_dataset_reg = create_hf_dataset(test_df, 'label_regression')

        regression_trainer, _ = train_regression_model(
            train_dataset_reg,
            val_dataset_reg,
            args.output_dir,
            num_epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.learning_rate
        )

        # Evaluate on test set
        evaluate_on_test_set(regression_trainer, test_dataset_reg, 'regression')

        scores = regression_trainer.predict(test_dataset_reg).predictions
        if isinstance(scores, tuple):
            scores = scores[0]
        # Raw model output; not clipped to the 0-10 label range.
        test_output['pred_regression'] = np.asarray(scores, dtype=np.float64).reshape(-1)

    # Save test set with predictions for analysis
    test_output.to_csv(os.path.join(args.output_dir, "test_set.csv"), index=False)

    print("\n" + "="*60)
    print("TRAINING COMPLETE")
    print(f"Finished: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print(f"Models saved to: {args.output_dir}")
    print("="*60)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Fine-tune Spanish transformer for violence classification"
    )

    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Directory containing LLM-labeled CSV files"
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default="./trained_models",
        help="Directory to save trained models"
    )

    parser.add_argument(
        "--epochs",
        type=int,
        default=10,
        help="Number of training epochs"
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=16,
        help="Training batch size"
    )

    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-5,
        help="Learning rate"
    )

    parser.add_argument(
        "--test_size",
        type=float,
        default=0.15,
        help="Fraction of data for test set"
    )

    parser.add_argument(
        "--val_size",
        type=float,
        default=0.15,
        help="Fraction of data for validation set"
    )

    parser.add_argument(
        "--majority_threshold",
        type=int,
        default=2,
        help="Minimum LLMs agreeing for positive label (2-4)"
    )

    parser.add_argument(
        "--skip_binary",
        action="store_true",
        help="Skip binary classifier training"
    )

    parser.add_argument(
        "--skip_regression",
        action="store_true",
        help="Skip regression model training"
    )

    args = parser.parse_args()

    # Validate up front: otherwise a bad combination only fails inside
    # train_test_split, after every CSV has already been loaded.
    for flag, value in (("--test_size", args.test_size), ("--val_size", args.val_size)):
        if not 0 < value < 1:
            parser.error(f"{flag} must be strictly between 0 and 1 (got {value})")
    if args.test_size + args.val_size >= 1:
        parser.error(
            "--test_size + --val_size must be < 1 so that training data remains "
            f"(got {args.test_size} + {args.val_size})"
        )
    if args.majority_threshold < 1:
        parser.error(
            f"--majority_threshold must be at least 1 (got {args.majority_threshold})"
        )
    if args.skip_binary and args.skip_regression:
        print("Warning: both --skip_binary and --skip_regression are set; no model will be trained.")

    main(args)
